import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # necessário para plotagem 3D

def euler_3d(r0, v0, acceleration, dt, t_max):
    """
    Integra o sistema usando o método de Euler para 3D.
    
    Parâmetros:
      r0         : posição inicial (vetor numpy: [x0, y0, z0])
      v0         : velocidade inicial (vetor numpy: [vx0, vy0, vz0])
      acceleration: função que recebe (r, v, t) e retorna a aceleração (ax, ay, az)
      dt         : passo de tempo (segundos)
      t_max      : tempo total de simulação (segundos)
      
    Retorna:
      t_array : array de tempos
      r_array : array de posições (cada linha é um vetor [x, y, z])
      v_array : array de velocidades (cada linha é um vetor [vx, vy, vz])
    """
    # Criar arrays de tempo e listas para armazenar as trajetórias
    t_array = np.arange(0, t_max + dt, dt)
    r_list = [r0]
    v_list = [v0]
    
    r = r0.copy()
    v = v0.copy()
    
    # Loop de integração via método de Euler
    for t in t_array[:-1]:
        # Calcula a aceleração atual
        a = acceleration(r, v, t)
        # Atualiza a velocidade
        v = v + a * dt
        # Atualiza a posição usando a velocidade (nota: aqui é possível usar v ou v + a*dt se quiser uma aproximação mais precisa)
        r = r + v * dt
        
        r_list.append(r.copy())
        v_list.append(v.copy())
    
    return t_array, np.array(r_list), np.array(v_list)

# Exemplo de função de aceleração: somente gravidade no eixo z.
def acceleration(r, v, t):
    # Considera apenas a aceleração devido à gravidade
    ax = 0.0
    ay = 0.0
    az = -9.81  # m/s² (gravidade)
    return np.array([ax, ay, az])

# Condições iniciais
r0 = np.array([0.0, 2.0, 3.0])  # Posição inicial (x, y, z)
v0 = np.array([44.444, 5.556, -5.556])  # Velocidade inicial (vx, vy, vz) em m/s

# Parâmetros da simulação
dt = 0.01   # passo de tempo
t_max = 5.0  # tempo máximo da simulação

# Executa a integração
t_array, r_array, v_array = euler_3d(r0, v0, acceleration, dt, t_max)

# Plot da trajetória 3D
fig = plt.figure(figsize=(10, 6))
ax = fig.add_subplot(111, projection='3d')
ax.plot(r_array[:, 0], r_array[:, 1], r_array[:, 2], 'b-', label='Trajetória')
ax.set_xlabel('X (m)')
ax.set_ylabel('Y (m)')
ax.set_zlabel('Z (m)')
ax.set_title("Trajetória 3D usando o Método de Euler")
ax.legend()
plt.show()
